{
 "cells": [
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%env LLM_API_KEY=替换为自己的Qwen API Key，打分用\n",
    "%env LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1"
   ]
  },
  {
   "metadata": {},
   "cell_type": "code",
   "outputs": [],
   "execution_count": null,
   "source": [
    "%%capture --no-stderr\n",
    "!pip install -U langchain langchain_community langchain_openai pypdf sentence_transformers chromadb shutil openpyxl FlagEmbedding"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:24:49.646845Z",
     "iopub.status.busy": "2024-10-14T12:24:49.646141Z",
     "iopub.status.idle": "2024-10-14T12:24:52.557242Z",
     "shell.execute_reply": "2024-10-14T12:24:52.556771Z",
     "shell.execute_reply.started": "2024-10-14T12:24:49.646780Z"
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/anaconda3/lib/python3.10/site-packages/sentence_transformers/cross_encoder/CrossEncoder.py:11: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
      "  from tqdm.autonotebook import tqdm, trange\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "langchain                     0.2.10\n",
      "langchain_core                0.2.28\n",
      "langchain_community           0.2.9\n",
      "pypdf                         4.3.1\n",
      "sentence_transformers         3.0.1\n",
      "chromadb                      0.5.4\n"
     ]
    }
   ],
   "source": [
    "import langchain, langchain_community, pypdf, sentence_transformers, chromadb, langchain_core\n",
    "\n",
    "for module in (langchain, langchain_core, langchain_community, pypdf, sentence_transformers, chromadb):\n",
    "    print(f\"{module.__name__:<30}{module.__version__}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:24:52.558093Z",
     "iopub.status.busy": "2024-10-14T12:24:52.557710Z",
     "iopub.status.idle": "2024-10-14T12:24:52.560015Z",
     "shell.execute_reply": "2024-10-14T12:24:52.559697Z",
     "shell.execute_reply.started": "2024-10-14T12:24:52.558080Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:24:52.560548Z",
     "iopub.status.busy": "2024-10-14T12:24:52.560426Z",
     "iopub.status.idle": "2024-10-14T12:24:52.573774Z",
     "shell.execute_reply": "2024-10-14T12:24:52.573335Z",
     "shell.execute_reply.started": "2024-10-14T12:24:52.560536Z"
    }
   },
   "outputs": [],
   "source": [
    "expr_version = 'retrieval_v9_parent_document_retriever'\n",
    "\n",
    "preprocess_output_dir = os.path.join(os.path.pardir, 'outputs', 'v1_20240713')\n",
    "expr_dir = os.path.join(os.path.pardir, 'experiments', expr_version)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 读取文档"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:24:52.574938Z",
     "iopub.status.busy": "2024-10-14T12:24:52.574802Z",
     "iopub.status.idle": "2024-10-14T12:24:54.091047Z",
     "shell.execute_reply": "2024-10-14T12:24:54.090581Z",
     "shell.execute_reply.started": "2024-10-14T12:24:52.574916Z"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_community.document_loaders import PyPDFLoader\n",
    "\n",
    "loader = PyPDFLoader(os.path.join(os.path.pardir, 'data', '2024全球经济金融展望报告.pdf'))\n",
    "documents = loader.load()\n",
    "\n",
    "qa_df = pd.read_excel(os.path.join(preprocess_output_dir, 'question_answer.xlsx'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:24:54.091870Z",
     "iopub.status.busy": "2024-10-14T12:24:54.091636Z",
     "iopub.status.idle": "2024-10-14T12:24:54.096084Z",
     "shell.execute_reply": "2024-10-14T12:24:54.095774Z",
     "shell.execute_reply.started": "2024-10-14T12:24:54.091858Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "53"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(documents)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 文档切分"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 现有切分方法"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-11T09:30:02.953004Z",
     "iopub.status.busy": "2024-10-11T09:30:02.952819Z",
     "iopub.status.idle": "2024-10-11T09:30:02.955603Z",
     "shell.execute_reply": "2024-10-11T09:30:02.955088Z",
     "shell.execute_reply.started": "2024-10-11T09:30:02.952989Z"
    }
   },
   "source": [
    "现有切分方法这部分不是必须的，但后续在评估检索性能时使用到了文档片段的uuid，为了能够评估检索效果，此处还是加入现有切分方法，通过关联知识片段的方法把uuid关联到ParentDocumentRetrieval，这样"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:24:54.180185Z",
     "iopub.status.busy": "2024-10-14T12:24:54.180059Z",
     "iopub.status.idle": "2024-10-14T12:24:54.185324Z",
     "shell.execute_reply": "2024-10-14T12:24:54.184940Z",
     "shell.execute_reply.started": "2024-10-14T12:24:54.180173Z"
    }
   },
   "outputs": [],
   "source": [
    "from uuid import uuid4\n",
    "import os\n",
    "import pickle\n",
    "\n",
    "def split_docs(documents, filepath, chunk_size=400, chunk_overlap=40, seperators=['\\n\\n\\n', '\\n\\n'], force_split=False):\n",
    "    if os.path.exists(filepath) and not force_split:\n",
    "        print('found cache, restoring...')\n",
    "        return pickle.load(open(filepath, 'rb'))\n",
    "\n",
    "    splitter = RecursiveCharacterTextSplitter(\n",
    "        chunk_size=chunk_size,\n",
    "        chunk_overlap=chunk_overlap,\n",
    "        separators=seperators\n",
    "    )\n",
    "    split_docs = splitter.split_documents(documents)\n",
    "    for chunk in split_docs:\n",
    "        chunk.metadata['uuid'] = str(uuid4())\n",
    "\n",
    "    pickle.dump(split_docs, open(filepath, 'wb'))\n",
    "\n",
    "    return split_docs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:24:54.185842Z",
     "iopub.status.busy": "2024-10-14T12:24:54.185720Z",
     "iopub.status.idle": "2024-10-14T12:24:54.192047Z",
     "shell.execute_reply": "2024-10-14T12:24:54.191623Z",
     "shell.execute_reply.started": "2024-10-14T12:24:54.185831Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "found cache, restoring...\n"
     ]
    }
   ],
   "source": [
    "splitted_docs = split_docs(documents, os.path.join(preprocess_output_dir, 'split_docs.pkl'), chunk_size=500, chunk_overlap=50)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 检索"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 准备检索器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:24:54.192555Z",
     "iopub.status.busy": "2024-10-14T12:24:54.192435Z",
     "iopub.status.idle": "2024-10-14T12:24:54.203692Z",
     "shell.execute_reply": "2024-10-14T12:24:54.203232Z",
     "shell.execute_reply.started": "2024-10-14T12:24:54.192544Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "device: cuda\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "\n",
    "from langchain.embeddings import HuggingFaceBgeEmbeddings\n",
    "from langchain_community.vectorstores import Chroma\n",
    "from langchain.storage import InMemoryStore\n",
    "\n",
    "model_path = 'BAAI/bge-large-zh-v1.5'\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f'device: {device}')\n",
    "\n",
    "def get_embeddings(model_path):\n",
    "    embeddings = HuggingFaceBgeEmbeddings(\n",
    "        model_name=model_path,\n",
    "        model_kwargs={'device': device},\n",
    "        encode_kwargs={'normalize_embeddings': True},\n",
    "        query_instruction='为这个句子生成表示以用于检索相关文章：'\n",
    "    )\n",
    "    return embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:24:54.204242Z",
     "iopub.status.busy": "2024-10-14T12:24:54.204120Z",
     "iopub.status.idle": "2024-10-14T12:24:54.829231Z",
     "shell.execute_reply": "2024-10-14T12:24:54.828782Z",
     "shell.execute_reply.started": "2024-10-14T12:24:54.204231Z"
    }
   },
   "outputs": [],
   "source": [
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from langchain.retrievers import ParentDocumentRetriever\n",
    "\n",
    "parent_splitter = RecursiveCharacterTextSplitter(\n",
    "    chunk_size=2000,\n",
    "    chunk_overlap=100,\n",
    "    separators=['\\n\\n\\n', '\\n\\n', '\\n']\n",
    ")\n",
    "child_splitter = RecursiveCharacterTextSplitter(\n",
    "    chunk_size=500,\n",
    "    chunk_overlap=50,\n",
    "    separators=['\\n\\n\\n', '\\n\\n', '\\n']\n",
    ")\n",
    "\n",
    "vectorstore = Chroma(\n",
    "    collection_name='split_parents', embedding_function=get_embeddings(model_path)\n",
    ")\n",
    "store = InMemoryStore()\n",
    "\n",
    "retriever = ParentDocumentRetriever(\n",
    "    vectorstore=vectorstore,\n",
    "    docstore=store,\n",
    "    child_splitter=child_splitter,\n",
    "    parent_splitter=parent_splitter,\n",
    "    search_kwargs={'k': 3}\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:24:54.831169Z",
     "iopub.status.busy": "2024-10-14T12:24:54.830986Z",
     "iopub.status.idle": "2024-10-14T12:25:00.375045Z",
     "shell.execute_reply": "2024-10-14T12:25:00.374526Z",
     "shell.execute_reply.started": "2024-10-14T12:24:54.831156Z"
    }
   },
   "outputs": [],
   "source": [
    "retriever.add_documents(documents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:25:00.375782Z",
     "iopub.status.busy": "2024-10-14T12:25:00.375619Z",
     "iopub.status.idle": "2024-10-14T12:25:00.398852Z",
     "shell.execute_reply": "2024-10-14T12:25:00.398405Z",
     "shell.execute_reply.started": "2024-10-14T12:25:00.375770Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "234"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(vectorstore.similarity_search('2023年全球经济增长的特点是什么？')[0].page_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:25:00.399491Z",
     "iopub.status.busy": "2024-10-14T12:25:00.399364Z",
     "iopub.status.idle": "2024-10-14T12:25:00.629116Z",
     "shell.execute_reply": "2024-10-14T12:25:00.628541Z",
     "shell.execute_reply.started": "2024-10-14T12:25:00.399479Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "679"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(retriever.invoke('2023年全球经济增长的特点是什么？')[0].page_content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:25:00.630068Z",
     "iopub.status.busy": "2024-10-14T12:25:00.629728Z",
     "iopub.status.idle": "2024-10-14T12:25:00.642303Z",
     "shell.execute_reply": "2024-10-14T12:25:00.641748Z",
     "shell.execute_reply.started": "2024-10-14T12:25:00.630051Z"
    }
   },
   "outputs": [],
   "source": [
    "from langchain.llms import Ollama\n",
    "\n",
    "ollama_llm = Ollama(\n",
    "    model='qwen2:7b-instruct',\n",
    "    base_url='http://localhost:11434'\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:25:00.643039Z",
     "iopub.status.busy": "2024-10-14T12:25:00.642875Z",
     "iopub.status.idle": "2024-10-14T12:25:00.646440Z",
     "shell.execute_reply": "2024-10-14T12:25:00.645901Z",
     "shell.execute_reply.started": "2024-10-14T12:25:00.643024Z"
    }
   },
   "outputs": [],
   "source": [
    "def rag(question):\n",
    "    prompt_tmpl = \"\"\"\n",
    "你是一个金融分析师，擅长根据所获取的信息片段，对问题进行分析和推理。\n",
    "你的任务是根据所获取的信息片段（<<<<context>>><<<</context>>>之间的内容）回答问题。\n",
    "回答保持简洁，不必重复问题，不要添加描述性解释和与答案无关的任何内容。\n",
    "已知信息：\n",
    "<<<<context>>>\n",
    "{{knowledge}}\n",
    "<<<</context>>>\n",
    "\n",
    "问题：{{question}}\n",
    "请回答：\n",
    "\"\"\".strip()\n",
    "\n",
    "    chunks = retriever.invoke(question)\n",
    "    prompt = prompt_tmpl.replace('{{knowledge}}', '\\n\\n'.join([doc.page_content for doc in chunks])).replace('{{question}}', question)\n",
    "\n",
    "    return ollama_llm.invoke(prompt), chunks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:25:00.647341Z",
     "iopub.status.busy": "2024-10-14T12:25:00.647132Z",
     "iopub.status.idle": "2024-10-14T12:25:10.776040Z",
     "shell.execute_reply": "2024-10-14T12:25:10.775677Z",
     "shell.execute_reply.started": "2024-10-14T12:25:00.647325Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2023年全球经济增长的特点是“复苏+分化”，即全球经济在经历一段时间的调整后开始复苏，但不同地区、国家间的增长速度和表现存在显著差异。发达经济体增速明显放缓，而新兴经济体增长也面临挑战与机遇并存。\n",
      "\n",
      "具体来看：\n",
      "\n",
      "1. **发达国家**：\n",
      "   - **美国经济**：未受加息明显冲击，出现超预期增长。服务领域的消费支出稳定增长，并且受益于就业市场稳健、劳动者实际收入增加和政府政策支持（如《通胀削减法案》和《芯片与科学法案》），推动了经济增长。\n",
      "   - **欧元区和英国**：经济增长显著放缓，三季度GDP环比增速由正转负。\n",
      "\n",
      "2. **新兴经济体**：\n",
      "   - **东南亚地区**（例如菲律宾、印度尼西亚）：居民消费支出增长强劲，对GDP有较高拉动作用。经济增长率相对较高。\n",
      "   - **中东地区**：经济增长依靠非能源领域，但由于高基数效应，增速明显减弱。沙特的GDP增长率降低至1.2%。\n",
      "   - **拉美新兴经济体**（如巴西、墨西哥）：增长表现好于部分国家，GDP增速超过3%，但阿根廷面临通胀与负增长问题。\n",
      "   - **非洲新兴经济体**：整体增长疲软，南非GDP同比增长为年内最高增速。各国通胀走势分化，南非通胀率相对较低，埃及和尼日利亚物价水平上涨。\n",
      "\n",
      "全球贸易形势的好转有望带动东南亚经济体出口恢复，预计2024年全球经济增速将略高于2023年的4%，但仍低于前一年度的水平，约为2.5%。\n"
     ]
    }
   ],
   "source": [
    "print(rag('2023年全球经济增长的特点是什么？')[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:25:10.776661Z",
     "iopub.status.busy": "2024-10-14T12:25:10.776531Z",
     "iopub.status.idle": "2024-10-14T12:25:10.781491Z",
     "shell.execute_reply": "2024-10-14T12:25:10.781146Z",
     "shell.execute_reply.started": "2024-10-14T12:25:10.776648Z"
    }
   },
   "outputs": [],
   "source": [
    "from tqdm.auto import tqdm\n",
    "\n",
    "prediction_df = qa_df[qa_df['dataset'] == 'test'][['uuid', 'question', 'qa_type', 'answer']].rename(columns={'answer': 'ref_answer'})\n",
    "\n",
    "def predict(prediction_df):\n",
    "    prediction_df = prediction_df.copy()\n",
    "    answer_dict = {}\n",
    "\n",
    "    for idx, row in tqdm(prediction_df.iterrows(), total=len(prediction_df)):\n",
    "        uuid = row['uuid']\n",
    "        question = row['question']\n",
    "        answer, chunks = rag(question)\n",
    "\n",
    "        answer_dict[question] = {\n",
    "            'uuid': uuid,\n",
    "            'ref_answer': row['ref_answer'],\n",
    "            'gen_answer': answer,\n",
    "            'chunks': chunks\n",
    "        }\n",
    "    prediction_df.loc[:, 'gen_answer'] = prediction_df['question'].apply(lambda q: answer_dict[q]['gen_answer'])\n",
    "    prediction_df.loc[:, 'chunks'] = prediction_df['question'].apply(lambda q: answer_dict[q]['chunks'])\n",
    "\n",
    "    return prediction_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:25:10.782196Z",
     "iopub.status.busy": "2024-10-14T12:25:10.781915Z",
     "iopub.status.idle": "2024-10-14T12:29:01.892917Z",
     "shell.execute_reply": "2024-10-14T12:29:01.892412Z",
     "shell.execute_reply.started": "2024-10-14T12:25:10.782184Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6afcaf308eb24aee81efb931bbae4f89",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pred_df = predict(prediction_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 评估"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:29:01.893709Z",
     "iopub.status.busy": "2024-10-14T12:29:01.893564Z",
     "iopub.status.idle": "2024-10-14T12:29:02.207574Z",
     "shell.execute_reply": "2024-10-14T12:29:02.207099Z",
     "shell.execute_reply.started": "2024-10-14T12:29:01.893695Z"
    }
   },
   "outputs": [],
   "source": [
    "from langchain_openai import ChatOpenAI\n",
    "import time\n",
    "\n",
    "judge_llm = ChatOpenAI(\n",
    "    api_key=os.environ['LLM_API_KEY'],\n",
    "    base_url=os.environ['LLM_BASE_URL'],\n",
    "    model_name='qwen2-72b-instruct',\n",
    "    temperature=0\n",
    ")\n",
    "\n",
    "def evaluate(prediction_df):\n",
    "    \"\"\"\n",
    "    对预测结果进行打分\n",
    "    :param prediction_df: 预测结果，需要包含问题，参考答案，生成的答案，列名分别为question, ref_answer, gen_answer\n",
    "    :return 打分模型原始返回结果\n",
    "    \"\"\"\n",
    "    prompt_tmpl = \"\"\"\n",
    "你是一个经济学博士，现在我有一系列问题，有一个助手已经对这些问题进行了回答，你需要参照参考答案，评价这个助手的回答是否正确，仅回复“是”或“否”即可，不要带其他描述性内容或无关信息。\n",
    "问题：\n",
    "<question>\n",
    "{{question}}\n",
    "</question>\n",
    "\n",
    "参考答案：\n",
    "<ref_answer>\n",
    "{{ref_answer}}\n",
    "</ref_answer>\n",
    "\n",
    "助手回答：\n",
    "<gen_answer>\n",
    "{{gen_answer}}\n",
    "</gen_answer>\n",
    "请评价：\n",
    "    \"\"\"\n",
    "    results = []\n",
    "\n",
    "    for _, row in tqdm(prediction_df.iterrows(), total=len(prediction_df)):\n",
    "        question = row['question']\n",
    "        ref_answer = row['ref_answer']\n",
    "        gen_answer = row['gen_answer']\n",
    "\n",
    "        prompt = prompt_tmpl.replace('{{question}}', question).replace('{{ref_answer}}', str(ref_answer)).replace('{{gen_answer}}', gen_answer).strip()\n",
    "        result = judge_llm.invoke(prompt).content\n",
    "        results.append(result)\n",
    "\n",
    "        time.sleep(1)\n",
    "    return results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:29:02.208394Z",
     "iopub.status.busy": "2024-10-14T12:29:02.208116Z",
     "iopub.status.idle": "2024-10-14T12:31:35.962846Z",
     "shell.execute_reply": "2024-10-14T12:31:35.960853Z",
     "shell.execute_reply.started": "2024-10-14T12:29:02.208381Z"
    }
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "26e7e2e2a1434dc5b7c20864c7235187",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "  0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "pred_df['raw_score'] = evaluate(pred_df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:31:35.965994Z",
     "iopub.status.busy": "2024-10-14T12:31:35.965284Z",
     "iopub.status.idle": "2024-10-14T12:31:35.980541Z",
     "shell.execute_reply": "2024-10-14T12:31:35.978308Z",
     "shell.execute_reply.started": "2024-10-14T12:31:35.965927Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array(['是', '否'], dtype=object)"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_df['raw_score'].unique()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:31:35.983581Z",
     "iopub.status.busy": "2024-10-14T12:31:35.982843Z",
     "iopub.status.idle": "2024-10-14T12:31:35.995298Z",
     "shell.execute_reply": "2024-10-14T12:31:35.993182Z",
     "shell.execute_reply.started": "2024-10-14T12:31:35.983515Z"
    }
   },
   "outputs": [],
   "source": [
    "pred_df['score'] = (pred_df['raw_score'] == '是').astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:31:35.998120Z",
     "iopub.status.busy": "2024-10-14T12:31:35.997421Z",
     "iopub.status.idle": "2024-10-14T12:31:36.013051Z",
     "shell.execute_reply": "2024-10-14T12:31:36.010780Z",
     "shell.execute_reply.started": "2024-10-14T12:31:35.998043Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0.67"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_df['score'].mean()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:31:36.016342Z",
     "iopub.status.busy": "2024-10-14T12:31:36.015614Z",
     "iopub.status.idle": "2024-10-14T12:31:36.029833Z",
     "shell.execute_reply": "2024-10-14T12:31:36.029274Z",
     "shell.execute_reply.started": "2024-10-14T12:31:36.016274Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>qa_type</th>\n",
       "      <th>score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>detailed</td>\n",
       "      <td>0.655914</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>large_context</td>\n",
       "      <td>0.857143</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         qa_type     score\n",
       "0       detailed  0.655914\n",
       "1  large_context  0.857143"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pred_df[['qa_type', 'score']].groupby('qa_type').mean().reset_index()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "execution": {
     "iopub.execute_input": "2024-10-14T12:31:36.030445Z",
     "iopub.status.busy": "2024-10-14T12:31:36.030312Z",
     "iopub.status.idle": "2024-10-14T12:31:36.240578Z",
     "shell.execute_reply": "2024-10-14T12:31:36.238231Z",
     "shell.execute_reply.started": "2024-10-14T12:31:36.030433Z"
    }
   },
   "outputs": [],
   "source": [
    "pred_df.to_excel(os.path.join(expr_dir, f'{expr_version}_prediction.xlsx'), index=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
